# Copyright 2023, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AggregationFactory for sampling per-client federated learning metrics."""

from typing import Union

from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import sampling
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning.metrics import aggregation_utils
from tensorflow_federated.python.learning.metrics import types


class FinalizeThenSampleFactory(factory.UnweightedAggregationFactory):
  """Aggregation Factory that finalizes and then samples the metrics.

  The created `tff.templates.AggregationProcess` finalizes each client's metrics
  locally, and then collects metrics from at most `sample_size` clients at the
  `tff.SERVER`. If more than `sample_size` clients are participating, then
  `sample_size` clients are sampled (by reservoir sampling algorithm);
  otherwise, all clients' metrics are collected. Sampling is done in a
  "per-client" manner, i.e., a client, once sampled, will contribute all its
  metrics to the final result.

  The collected metrics samples at `tff.SERVER` have the same structure (i.e.,
  same keys in a dictionary) as the client's local metrics, except that each
  leaf node contains a list of scalar metric values, where each value comes from
  a sampled client, e.g.,
    ```
    sampled_metrics_at_server = {
        'metric_a': [a1, a2, ...],
        'metric_b': [b1, b2, ...],
        ...
    }
    ```
  where "a1" and "b1" are from the same client (similarly for "a2" and "b2"
  etc).

  Both "current round samples" and "total rounds samples" are returned, and
  they both contain metrics from at most `sample_size` clients. Sampling is
  done across the current round's participating clients (the result is
  "current round samples") or across all the participating clients so far (the
  result is "total rounds samples").

  The `next` function of the created `tff.templates.AggregationProcess` takes
  the `state` and local unfinalized metrics reported from `tff.CLIENTS`, and
  returns a `tff.templates.MeasuredProcessOutput` object with the following
  properties:
    - `state`: a dictionary of total rounds samples and the sampling metadata (
        e.g., random values generated by the reservoir sampling algorithm).
    - `result`: a tuple of current round samples and total rounds samples.
    - `measurements`: the number of non-finite (`NaN` or `Inf` values) leaves in
        the current round client values *before* sampling.

  Example usage:
    ```
    sample_process = FinalizeThenSampleFactory(sample_size).create(
        metric_finalizers, local_unfinalized_metrics_type)
    eval_process = tff.learning.algorithms.build_fed_eval(
        model_fn=..., metrics_aggregation_process=sample_process, ...)
    state = eval_process.initialize()
    for i in range(num_rounds):
      output = eval_process.next(state, client_data_at_round_i)
      state = output.state
      current_round_samples, total_rounds_samples = output.result
    ```
  The created `eval_process` can also be used in
  `tff.learning.programs.EvaluationManager`.
  """

  def __init__(self, sample_size: int = 100):
    """Initialize the `FinalizeThenSampleFactory`.

    Args:
      sample_size: An integer specifying the number of clients sampled (by
        reservoir sampling algorithm). Metrics from the sampled clients are
        collected at the server, and this `sample_size` applies to current round
        and total rounds samples (see the class documentation for details).
        Default value is 100.

    Raises:
      TypeError: If any argument type mismatches (including a `bool` passed as
        `sample_size`).
      ValueError: If `sample_size` is not positive.
    """
    # `bool` is a subclass of `int`, so the generic int check below would
    # silently accept `True` as a sample size of 1. Reject it explicitly.
    if isinstance(sample_size, bool):
      raise TypeError('sample_size must be an int, found a bool.')
    py_typecheck.check_type(sample_size, int, 'sample_size')
    if sample_size <= 0:
      raise ValueError('sample_size must be positive.')
    self._sample_size = sample_size

  def create(
      self,
      metric_finalizers: Union[
          types.MetricFinalizersType,
          types.FunctionalMetricFinalizersType,
      ],
      local_unfinalized_metrics_type: computation_types.StructWithPythonType,
  ) -> aggregation_process.AggregationProcess:
    """Creates a `tff.templates.AggregationProcess` for metrics aggregation.

    Args:
      metric_finalizers: Either the result of
        `tff.learning.models.VariableModel.metric_finalizers` (an `OrderedDict`
        of callables) or the
        `tff.learning.models.FunctionalModel.finalize_metrics` method (a
        callable that takes an `OrderedDict` argument). If the former, the keys
        must be the same as the `OrderedDict` returned by
        `tff.learning.models.VariableModel.report_local_unfinalized_metrics`. If
        the latter, the callable must compute over the same keyspace of the
        result returned by
        `tff.learning.models.FunctionalModel.update_metrics_state`.
      local_unfinalized_metrics_type: A `tff.types.StructWithPythonType` (with
        `collections.OrderedDict` as the Python container) of a client's local
        unfinalized metrics.

    Returns:
      An instance of `tff.templates.AggregationProcess`.

    Raises:
      TypeError: If any argument type mismatches, or if the metric finalizers
        mismatch the type of local unfinalized metrics.
    """
    aggregation_utils.check_metric_finalizers(metric_finalizers)
    aggregation_utils.check_local_unfinalized_metrics_type(
        local_unfinalized_metrics_type
    )
    if not callable(metric_finalizers):
      # If we have a FunctionalMetricsFinalizerType it's a function that can
      # only be checked when we call it, as users may have used *args/**kwargs
      # arguments or otherwise making it hard to deduce the type.
      aggregation_utils.check_finalizers_matches_unfinalized_metrics(
          metric_finalizers, local_unfinalized_metrics_type
      )
    local_finalize_computation = aggregation_utils.build_finalizer_computation(
        metric_finalizers, local_unfinalized_metrics_type
    )
    local_finalized_metrics_type = (
        local_finalize_computation.type_signature.result
    )
    # Samples the current round's clients; metadata is kept so the round's
    # samples can later be merged into the cross-round reservoir.
    sampling_process = sampling.UnweightedReservoirSamplingFactory(
        sample_size=self._sample_size, return_sampling_metadata=True
    ).create(local_finalized_metrics_type)
    # Depends only on values fixed at `create` time, so build it once here
    # rather than inside the `next_fn` trace.
    merge_samples_computation = sampling.build_merge_samples_computation(
        value_type=local_finalized_metrics_type, sample_size=self._sample_size
    )

    @federated_computation.federated_computation
    def init_fn():
      @tensorflow_computation.tf_computation
      def create_initial_sample_state():
        return sampling.build_initial_sample_reservoir(
            local_finalized_metrics_type
        )

      return intrinsics.federated_eval(
          create_initial_sample_state, placements.SERVER
      )

    # We cannot directly use `init_fn.type_signature.result` as the server state
    # type because the `init_fn` returns values with 0 shape. This fails to
    # capture the fact that the state can grow in size over rounds. Instead, we
    # should use `None` to denote the shape that can change over rounds.
    state_type = sampling.build_reservoir_type(local_finalized_metrics_type)

    @federated_computation.federated_computation(
        computation_types.FederatedType(state_type, placements.SERVER),
        computation_types.FederatedType(
            local_unfinalized_metrics_type, placements.CLIENTS
        ),
    )
    def next_fn(state, client_unfinalized_metrics):
      # Finalize each client's metrics locally (at `tff.CLIENTS`) before any
      # sampling happens.
      local_finalized_metrics = intrinsics.federated_map(
          local_finalize_computation, client_unfinalized_metrics
      )
      # Sample this round's clients starting from the sampling process's
      # initial state, so the result reflects only the current round.
      current_round_sampling_output = sampling_process.next(
          sampling_process.initialize(), local_finalized_metrics
      )
      # Fold this round's samples (with their sampling metadata) into the
      # reservoir accumulated across rounds in the server state.
      new_state = intrinsics.federated_map(
          merge_samples_computation,
          (state, current_round_sampling_output.result),
      )
      current_round_samples = current_round_sampling_output.result['samples']
      total_rounds_samples = new_state['samples']
      return measured_process.MeasuredProcessOutput(
          state=new_state,
          result=intrinsics.federated_zip(
              (current_round_samples, total_rounds_samples)
          ),
          measurements=current_round_sampling_output.measurements,
      )

    return aggregation_process.AggregationProcess(init_fn, next_fn)
